import os
import numpy as np
from easy_io import write_pkl_file, read_pkl_file
import global_config


# fold_list = np.load('/data_4t/Kaggle/fold_list.npy')
# fold_list = np.load('/data_4t/Kaggle/fold_list_no_error.npy')
# label_dict = np.load('/data_4t/Kaggle/lidc&kaggle/label_dict3.npy').item()
#
# nb_total_scans_dict, nb_total_labels_dict = {}, {}
#
# for i in range(len(fold_list)):
#     scans = fold_list[i]
#     nb_total_scans_dict[i] = len(scans)
#     nb_total_labels_dict[i] = sum(len(label_dict[s]) for s in scans)
#
# os.makedirs(global_config.data_folder, exist_ok=True)
# write_pkl_file(os.path.join(global_config.data_folder, 'lidc_kaggle_info_v1.pkl'), (nb_total_scans_dict, nb_total_labels_dict))


# fold_list = np.load('/data_4t/Kaggle/fold_list_no_error.npy')
# label_dict = read_pkl_file('/ssd_1t/huzq/kaggle_data/lidc_kaggle_label_dict_3_backup.pkl')
#
# nb_total_scans_dict, nb_total_labels_dict, nb_total_maligns_dict = {}, {}, {}
#
# for i in range(len(fold_list)):
#     scans = fold_list[i]
#     nb_total_scans_dict[i] = len(scans)
#     nb_total_labels_dict[i] = sum(len(label_dict[s]) for s in scans)
#     nb_total_maligns_dict[i] = sum(sum(c['malign'] > 3.0 if c['source'] == 'lidc' else c['malign'] > 6.0 for c in label_dict[s]) for s in scans)
#
# os.makedirs(global_config.data_folder, exist_ok=True)
# write_pkl_file(os.path.join(global_config.data_folder, 'lidc_kaggle_info_v2.pkl'),
#                (nb_total_scans_dict, nb_total_labels_dict, nb_total_maligns_dict))

# Per-fold statistics (v3) over the merged LIDC + Kaggle label dictionaries:
# for every fold, the number of scans, the number of labelled candidates, and
# the number of candidates counted as malignant. The three dicts (keyed by
# fold index) are written as a tuple to
# <global_config.data_folder>/lidc_kaggle_info_v3.pkl.
#
# The label files are pickled Python objects stored in .npy containers
# (0-d object arrays unwrapped with .item()), and fold_list may itself be an
# object array if folds differ in length. NumPy >= 1.16.3 refuses to load
# object arrays unless allow_pickle=True is passed explicitly. Unpickling can
# execute arbitrary code, so only point these paths at trusted files.

# Threshold applied to candidates that carry a top-level 'malign' score.
MALIGN_SCORE_THRESHOLD = 6.0
# Threshold applied to candidates that instead carry attrs['malignancy'].
ATTRS_MALIGNANCY_THRESHOLD = 3.0


def _is_malignant(candidate):
    """Return whether one labelled candidate counts as malignant.

    A candidate with a ``'malign'`` key is compared against
    ``MALIGN_SCORE_THRESHOLD``. Any other candidate must provide
    ``candidate['attrs']['malignancy']``, which is compared against
    ``ATTRS_MALIGNANCY_THRESHOLD``. Presumably the two shapes correspond to
    the Kaggle and LIDC label sources respectively; confirm against the
    label files.

    Raises:
        KeyError: if the candidate has neither field.
    """
    if 'malign' in candidate:
        return candidate['malign'] > MALIGN_SCORE_THRESHOLD
    return candidate['attrs']['malignancy'] > ATTRS_MALIGNANCY_THRESHOLD


fold_list = np.load('/data_4t/Kaggle/fold_list_no_error.npy', allow_pickle=True)
# label_dict = read_pkl_file('/ssd_1t/huzq/kaggle_data/lidc_kaggle_label_dict_3_backup.pkl')
lidc_3 = np.load('/data_4t/Kaggle/backup/lidc/label_dict_3_0329.npy', allow_pickle=True).item()
kaggle = np.load('/data_4t/Kaggle/backup/kaggle/label_dict_0325.npy', allow_pickle=True).item()

# Merge both sources into one mapping; on a duplicate key the Kaggle entry
# overrides the LIDC one because it is applied last.
label_dict_3 = dict(lidc_3)
label_dict_3.update(kaggle)
label_dict = label_dict_3

nb_total_scans_dict, nb_total_labels_dict, nb_total_maligns_dict = {}, {}, {}

# Each fold is a sized iterable of keys into label_dict; each label_dict value
# is an iterable of candidate dicts.
for i, scans in enumerate(fold_list):
    nb_total_scans_dict[i] = len(scans)
    nb_total_labels_dict[i] = sum(len(label_dict[s]) for s in scans)
    nb_total_maligns_dict[i] = sum(
        sum(_is_malignant(c) for c in label_dict[s]) for s in scans
    )

os.makedirs(global_config.data_folder, exist_ok=True)
write_pkl_file(os.path.join(global_config.data_folder, 'lidc_kaggle_info_v3.pkl'),
               (nb_total_scans_dict, nb_total_labels_dict, nb_total_maligns_dict))
